Add API to list coopMat/coopVec types and combinations#10076
cmarcelo wants to merge 4 commits into shader-slang:master
📝 Walkthrough
This pull request adds comprehensive support for cooperative matrix and vector operations across the Slang compiler. It introduces new public API types and metadata interfaces for exposing cooperative type information, extends IR with five new cooperative instructions and validation logic, implements codegen for multiple targets (CUDA OptiX, HLSL, SPIR-V), and adds three new scalar types (BFloat16, FloatE4M3, FloatE5M2) plus pointer-sized integer types.
Pull request overview
This pull request adds a new API to expose cooperative matrix type metadata, allowing applications to query which cooperative matrix type combinations are used in compiled shaders. This addresses issue #10021, where applications need to verify driver support for specific cooperative matrix type combinations used in shaders.
Changes:
- Adds new public API types (`SlangScope`, `SlangCooperativeMatrixUse`, `SlangCooperativeComponentType`, `SlangCooperativeMatrixType`) and the `ICooperativeMatrixMetadata` interface to slang.h
- Implements metadata collection that recursively scans IR to identify all cooperative matrix types used in the compiled code
- Includes comprehensive unit tests for both the new cooperative matrix metadata API and the binarySearch changes from dependency PRs
Reviewed changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| include/slang.h | Adds new enums, struct, and interface for querying cooperative matrix types used in compiled shaders |
| source/compiler-core/slang-artifact-associated-impl.h | Extends ArtifactPostEmitMetadata to implement ICooperativeMatrixMetadata interface |
| source/compiler-core/slang-artifact-associated-impl.cpp | Implements the cooperative matrix metadata query methods with proper error handling |
| source/slang/slang-ir-metadata.cpp | Adds IR type conversion helpers and recursive metadata collection for cooperative matrix types |
| source/core/slang-list.h | Updates binarySearch to return bitwise negation of insertion index on miss (from PR #10048) |
| source/slang/slang-check-decl.cpp | Replaces binarySearch with indexOf for declaration ordering (from PR #10047) |
| source/slang/slang-language-server-auto-format.cpp | Updates exclusion range check to use >= 0 instead of != -1 (from PR #10048) |
| tools/slang-unit-test/unit-test-cooperative-type-metadata.cpp | Comprehensive tests for the cooperative matrix metadata API |
| tools/slang-unit-test/unit-test-list.cpp | Unit tests for the updated binarySearch behavior |
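The `binarySearch` convention mentioned in the table (returning the bitwise negation of the insertion index on a miss) can be sketched as follows. This is a minimal standalone illustration on `std::vector<int>`; the function name `binarySearchOrInsertionPoint` is hypothetical and this is not Slang's actual `List` code.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper, not Slang's List::binarySearch.
// Returns the index of `key` if it is present. On a miss it returns
// ~insertionIndex, which is always negative, so `result >= 0` means "found".
long long binarySearchOrInsertionPoint(const std::vector<int>& sorted, int key)
{
    std::size_t lo = 0;
    std::size_t hi = sorted.size();
    while (lo < hi)
    {
        std::size_t mid = lo + (hi - lo) / 2;
        if (sorted[mid] < key)
            lo = mid + 1;
        else if (key < sorted[mid])
            hi = mid;
        else
            return static_cast<long long>(mid);
    }
    // `lo` is where `key` would have to be inserted to keep the order.
    return ~static_cast<long long>(lo);
}
```

Under this convention only a miss at insertion index 0 produces -1; other misses produce -2, -3, and so on. That is presumably why the exclusion range check in the table moves from `!= -1` to `>= 0`.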
Some context for this PR in #10021 (comment).
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@source/compiler-core/slang-artifact-associated-impl.cpp`:
- Around line 283-293: ArtifactPostEmitMetadata::getInterface is returning the
wrong interface pointer for the slang::IMetadata GUID (it returns
IArtifactPostEmitMetadata*), violating COM rules; update the GUID checks so that
when guid == slang::IMetadata::getTypeGuid() you return
static_cast<slang::IMetadata*>(this), when guid ==
IArtifactPostEmitMetadata::getTypeGuid() you return
static_cast<IArtifactPostEmitMetadata*>(this), and keep the ISlangUnknown,
ICastable and slang::ICooperativeMatrixMetadata cases returning their
corresponding static_casts so each GUID returns the correctly typed interface
pointer.
In `@source/slang/slang-check-decl.cpp`:
- Around line 3581-3582: The two indexOf() results (from
ancestor->getMembers().indexOf(subAncestor) and indexOf(supAncestor)) must be
validated for -1 before being used in arithmetic; update the code to check for
-1 and either assert (e.g., SLANG_ASSERT) or apply the same fallback/ordering
logic as used in _compareDeclsInCommonParentByOrderOfDeclaration() so a missing
member does not produce incorrect ordering. Locate the variables subIndex and
supIndex in this block and add the check/early-return or normalized ordering
fallback when either index is -1, and log or assert with clear context
referencing subAncestor/supAncestor to aid debugging.
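The first finding above concerns returning the correctly typed pointer for each interface GUID. A minimal sketch of why that matters with multiple inheritance follows. All type names, the `InterfaceId` enum, and the inheritance layout are hypothetical stand-ins and do not reproduce the real `ArtifactPostEmitMetadata` class.

```cpp
#include <cassert>

// Hypothetical stand-ins for interface GUIDs.
enum class InterfaceId { Metadata, PostEmit, CoopTypes };

// Hypothetical interfaces; each base carries its own vtable pointer.
struct IMetadataLike  { virtual int metadataValue() = 0; virtual ~IMetadataLike() = default; };
struct IPostEmitLike  { virtual int postEmitValue() = 0; virtual ~IPostEmitLike() = default; };
struct ICoopTypesLike { virtual int coopValue() = 0;     virtual ~ICoopTypesLike() = default; };

struct MetadataImpl : IPostEmitLike, IMetadataLike, ICoopTypesLike
{
    int postEmitValue() override { return 1; }
    int metadataValue() override { return 2; }
    int coopValue() override { return 3; }

    // Each ID maps to a static_cast to *that* interface. The casts adjust
    // the pointer to the matching base subobject, so the returned addresses
    // differ and callers can safely reinterpret the void* as the type they asked for.
    void* getInterface(InterfaceId id)
    {
        switch (id)
        {
        case InterfaceId::PostEmit:  return static_cast<IPostEmitLike*>(this);
        case InterfaceId::Metadata:  return static_cast<IMetadataLike*>(this);
        case InterfaceId::CoopTypes: return static_cast<ICoopTypesLike*>(this);
        }
        return nullptr;
    }
};
```

If the `Metadata` case returned the `IPostEmitLike*` address instead, a caller treating the result as `IMetadataLike*` would call through the wrong vtable, which is the kind of failure the review comment warns about.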
86eccf9 to 5878845
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tools/slang-unit-test/unit-test-cooperative-type-metadata.cpp`:
- Around line 37-50: The test must guard and skip when SPIR-V or the
spvCooperativeMatrixKHR capability is unavailable; before proceeding with
targetDesc/compilation use the project's test-skip pattern (SLANG_IGNORE_TEST)
to check that globalSession->findProfile("spirv_1_6") is non-null and that
int32_t(globalSession->findCapability("spvCooperativeMatrixKHR")) is non-zero
(or otherwise indicates support), and if either check fails call
SLANG_IGNORE_TEST with an explanatory message instead of continuing; update the
logic around globalSession, targetDesc.format (SLANG_SPIRV), and
capabilityOption so the test is skipped early when those features are absent.
5878845 to 187b4bb
jkwak-work left a comment
This PR wouldn't provide the information needed to query the combination of types that the graphics driver supports.
187b4bb to 9f61c6b
9f61c6b to 47bed14
@coderabbitai summary
✅ Actions performed: Summary regeneration triggered.
47bed14 to 12fda08
csyonghe left a comment
In general, I think we first need to refactor the way coopMatMulAdd and other coop-vec intrinsics are implemented in the compiler. We need to move the lowering logic from hlsl.meta.slang (library code) into each backend, so that we can map each intrinsic to its dedicated Slang IR opcode and simplify the metadata collection logic.
@cmarcelo, were you able to make some progress on this PR?
Working on the suggestions. Will update PR here in the next few days.
01070ec to a00945c
8d6a9d6 to ae3fd3c
40f4137 to 652560c
@jkwak-work ping
Sorry for the delay. I was OOTO for a week.
Expose cooperative matrix and cooperative vector metadata through a new `ICooperativeTypesMetadata` interface queried from `IMetadata`. This information can be used to query the combinations of types that the graphics driver supports. Closes shader-slang#10021.
enum SlangScope : SlangScopeIntegral
{
    SLANG_SCOPE_NONE,
    SLANG_SCOPE_THREAD,
    SLANG_SCOPE_WAVE,
    SLANG_SCOPE_THREAD_GROUP,
};
Is this the same as the one we already have?
slang/source/slang/slang-type-system-shared.h
Lines 119 to 130 in e4d533d
LLMs also point out problems with this enum.
I think the enum name should be more like MemoryScope rather than just Scope.
And the values should match the internal and SPIRV values.
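One way to read the suggestion about matching values is sketched below. The `SpirvScope` enumerators use the numeric values from the SPIR-V specification's Scope table. `ExampleMemoryScope` and `toSpirvScope` are hypothetical names for illustration only and are not what the PR or Slang's internal header actually defines.

```cpp
#include <cassert>
#include <cstdint>

// Scope values as listed in the SPIR-V specification.
enum class SpirvScope : std::uint32_t
{
    CrossDevice = 0,
    Device = 1,
    Workgroup = 2,
    Subgroup = 3,
    Invocation = 4,
    QueueFamily = 5,
};

// Hypothetical public enum whose numeric values mirror the SPIR-V ones,
// so converting between the two is a plain cast rather than a lookup table.
enum class ExampleMemoryScope : std::uint32_t
{
    Workgroup = 2,  // thread group
    Subgroup = 3,   // wave
    Invocation = 4, // single thread
};

constexpr SpirvScope toSpirvScope(ExampleMemoryScope scope)
{
    return static_cast<SpirvScope>(static_cast<std::uint32_t>(scope));
}

// Compile-time check that the two enums stay in sync.
static_assert(toSpirvScope(ExampleMemoryScope::Subgroup) == SpirvScope::Subgroup,
              "public scope values must match SPIR-V");
```

With sequential values starting at zero, as in the enum under review, the same conversion would need an explicit mapping.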
// Runs after target-specific lowering so it only captures cooperative types that remain
// as native constructs visible to the driver (see ICooperativeTypesMetadata docs).
SLANG_PASS(collectCooperativeMetadata, *metadata);
I think we can skip this pass when the cooperative capability is not used, or when no cooperative-vector or cooperative-matrix types were found earlier.
List<slang::CooperativeMatrixType> m_cooperativeMatrixTypes;
List<slang::CooperativeMatrixCombination> m_cooperativeMatrixCombinations;
List<slang::CooperativeVectorType> m_cooperativeVectorTypes;
List<slang::CooperativeVectorCombination> m_cooperativeVectorCombinations;
I am not sure List is the right container when you want to maintain sorted order.
Although the name List sounds like a linked list, it is more like std::vector.
The cost of inserting in the middle will be higher than you might expect.
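The cost being described can be seen in a minimal sketch of sorted, deduplicating insertion into a contiguous array. It uses `std::vector` as a stand-in for Slang's `List`, and the helper name `insertSortedUnique` is hypothetical, not the PR's `_insertSortedUnique`.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Inserts `value` into an already-sorted vector, keeping it sorted and
// free of duplicates. Returns true if the value was inserted.
template <typename T>
bool insertSortedUnique(std::vector<T>& sorted, const T& value)
{
    // O(log n) search for the first element that is not less than `value`.
    auto it = std::lower_bound(sorted.begin(), sorted.end(), value);
    if (it != sorted.end() && !(value < *it))
        return false; // an equivalent element is already present

    // O(n) in the worst case: every element after `it` is shifted by one.
    sorted.insert(it, value);
    return true;
}
```

Each insertion is linear in the number of elements after the insertion point. A common alternative is to append everything and then sort and deduplicate once at the end; whether the difference matters here depends on how many entries a shader realistically produces.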
Verdict: 🔴 Has issues — 1 bug, 3 gaps, 1 question
This PR adds ICooperativeTypesMetadata, a new COM interface for querying cooperative matrix/vector types and multiply-add combinations from compiled shader metadata. The implementation collects metadata from IR after target-specific lowering via a new collectCooperativeMetadata pass. The main correctness concern is the use of unchecked cast<> in the metadata collection which can crash in release builds if IR invariants don't hold; the remaining findings are about namespace hygiene, documentation, and test coverage for key internal invariants.
Changes Overview
Public API (include/slang.h)
- What changed: Added the `SlangScope` enum, four new structs (`CooperativeMatrixType`, `CooperativeMatrixCombination`, `CooperativeVectorType`, `CooperativeVectorCombination`), and a new `ICooperativeTypesMetadata` COM interface with count+getByIndex methods for each type/combination list.
Metadata implementation (source/slang/slang-ir-metadata.cpp, slang-ir-metadata.h)
- What changed: Added the `collectCooperativeMetadata` pass that walks the IR breadth-first after target-specific lowering, extracts cooperative matrix/vector types and multiply-add combinations, and stores them in sorted deduplicated lists on `ArtifactPostEmitMetadata`.
Metadata storage (source/compiler-core/slang-artifact-associated-impl.cpp, .h)
- What changed: `ArtifactPostEmitMetadata` now inherits `ICooperativeTypesMetadata`, stores four `List<>` members for cooperative types/combinations, and implements the eight get-count/get-by-index methods with null-pointer and bounds validation. COM `getInterface` updated for the new GUID.
Pipeline integration (source/slang/slang-emit.cpp)
- What changed: `collectCooperativeMetadata` invoked after target-specific lowering in `linkAndOptimizeIR`, so only backends that retain native cooperative types (SPIRV, HLSL, CUDA+OptiX) report non-empty metadata.
Tests (tools/slang-unit-test/unit-test-cooperative-type-metadata.cpp)
- What changed: New 861-line unit test file covering subgroup/workgroup matrix metadata, vector type/combination metadata, training operations, mixed training+non-training, lowered targets (plain CUDA), and empty shaders across SPIRV/HLSL/CUDA targets.
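The count+getByIndex pattern with null-pointer and bounds validation described in the overview could look roughly like the sketch below. The struct, class, method names, and result codes are all hypothetical; the real `ICooperativeTypesMetadata` method signatures are not shown on this page.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for one of the public cooperative type structs.
struct ExampleMatrixType
{
    std::uint32_t rowCount = 0;
    std::uint32_t columnCount = 0;
};

// Hypothetical result codes standing in for the API's real error values.
enum class ExampleResult { Ok, InvalidArg };

class ExampleMetadata
{
public:
    void add(const ExampleMatrixType& type) { m_types.push_back(type); }

    std::uint32_t getMatrixTypeCount() const
    {
        return static_cast<std::uint32_t>(m_types.size());
    }

    // Validates the output pointer and the index before copying anything out.
    ExampleResult getMatrixTypeByIndex(std::uint32_t index, ExampleMatrixType* outType) const
    {
        if (!outType)
            return ExampleResult::InvalidArg;
        if (index >= m_types.size())
            return ExampleResult::InvalidArg;
        *outType = m_types[index];
        return ExampleResult::Ok;
    }

private:
    std::vector<ExampleMatrixType> m_types;
};
```

A caller would first read the count, then loop over indices and check each result, which keeps the interface free of container types across the API boundary.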
Findings (5 total)
| Severity | Location | Finding |
|---|---|---|
| 🔴 Bug | slang-ir-metadata.cpp:469 | cast<> assumes prior validation; crashes in release on malformed IR |
| 🟡 Gap | slang-ir-metadata.cpp:13 | operator< in slang namespace for public types risks ODR violations |
| 🟡 Gap | include/slang.h:860 | New SlangScope enum lacks documentation |
| 🟡 Gap | unit-test-cooperative-type-metadata.cpp:831 | No tests for deduplication or maxSize merge logic |
| 🔵 Question | include/slang.h:4509 | .use defaults to USE_A rather than a sentinel value |
IRIntegerValue packingFactorValue =
    cast<IRIntLit>(matMulAdd->getInputInterpretationPackingFactor())->getValue();
🔴 Bug: Unsafe cast<> usage assumes prior validation holds unconditionally
Three call sites in the cooperative metadata collection use cast<IRIntLit> / cast<IRBoolLit> (here, line 483 for getTranspose(), and line 559 for getSaturatingAccumulation()) based on the comment that validateCooperativeOperations guarantees operand literal types. However, cast<> only asserts in debug builds — in release builds a null or wrong-type operand would silently produce a null pointer and immediately dereference it, causing a crash.
The assumption that validation always precedes this pass is fragile: any future IR pass inserted between validateCooperativeOperations and collectCooperativeMetadata could create or modify cooperative instructions without re-validation. The rest of this file (e.g., _getCooperativeMatrixType at lines 309-341) already uses the defensive as<IRIntLit> + null-check pattern consistently.
Example: If a future optimization pass between validation and metadata collection performs constant folding that replaces an IRIntLit with a different IR node type, cast<IRIntLit>(nullptr) returns null in release builds, and ->getValue() dereferences it.
Suggested fix:
auto packingLit = as<IRIntLit>(matMulAdd->getInputInterpretationPackingFactor());
if (!packingLit)
    return;
IRIntegerValue packingFactorValue = packingLit->getValue();

Apply the same pattern to the cast<IRBoolLit> calls for getTranspose() and getSaturatingAccumulation().
// Define operator< for public cooperative type structs, used internally.
namespace slang
{
bool operator<(const CooperativeMatrixType& a, const CooperativeMatrixType& b)
🟡 Gap: operator< defined in slang namespace for public API types
These operator< overloads are defined in the slang namespace for structs declared in the public header include/slang.h. While they are only used internally for sorted insertion within this .cpp file, defining them at namespace scope means they participate in ADL for these types. If any other translation unit (or a downstream consumer) defines a different operator< for the same types in namespace slang, it would be an ODR violation.
Suggestion: Move the comparison logic into an anonymous namespace using named comparator structs or lambdas, and pass them explicitly to the _insertSortedUnique/lowerBound calls. This keeps the ordering internal to this file without polluting the slang namespace:
namespace {
struct CooperativeMatrixTypeLess {
bool operator()(const slang::CooperativeMatrixType& a,
const slang::CooperativeMatrixType& b) const { /* ... */ }
};
} // anonymous namespace

typedef SlangUInt32 SlangScopeIntegral;
enum SlangScope : SlangScopeIntegral
{
    SLANG_SCOPE_NONE,
    SLANG_SCOPE_THREAD,
    SLANG_SCOPE_WAVE,
🟡 Gap: New public enum SlangScope lacks documentation
This is a new public API enum that maps GPU execution scopes for cooperative operations. Unlike the struct fields and the ICooperativeTypesMetadata interface (which have doc comments), the enum values have no documentation. Users need to understand the mapping to GPU concepts (thread = invocation, wave = subgroup, thread group = workgroup) and how these correspond to Vulkan/HLSL scopes.
Suggestion: Add a brief doc comment and per-value annotations:
/** Cooperative type execution scope.
* Specifies the scope of threads that cooperate in matrix/vector operations.
*/
typedef SlangUInt32 SlangScopeIntegral;
enum SlangScope : SlangScopeIntegral
{
SLANG_SCOPE_NONE, ///< Invalid/unspecified scope
SLANG_SCOPE_THREAD, ///< Single invocation
SLANG_SCOPE_WAVE, ///< Subgroup/wave scope
SLANG_SCOPE_THREAD_GROUP, ///< Workgroup/thread group scope
};

}
}

SLANG_UNIT_TEST(cooperativeMetadataEmptyShader)
🟡 Gap: No test for deduplication or cooperative vector maxSize merging
The implementation has two key invariants that are not directly validated by any test:
- Sorted unique insertion (`_insertSortedUnique`): If the same cooperative matrix type appears in multiple IR instructions, the metadata should contain it exactly once. The existing tests implicitly exercise this (types shared across matmul operands), but no test explicitly verifies that duplicate types are deduplicated; if the sorted-insertion logic regresses, the existing tests wouldn't necessarily catch it because they only check expected counts.
- `maxSize` merging (`_insertOrUpdateCooperativeVectorType`): When the same component type appears with different vector sizes (e.g., `CoopVec<4, float>` and `CoopVec<8, float>`), the implementation merges them into one entry with `maxSize = max(4, 8) = 8`. No test validates this merge behavior. Similarly, the `usedForTrainingOp` OR-merging when a type appears in both training and non-training contexts is only tested with the mixed test, but that test doesn't verify that `maxSize` is the maximum across both uses.
Suggestion: Add a targeted test that uses the same cooperative matrix type in multiple operations and asserts the type count is exactly 1 (not 2). Add a second test that uses the same vector component type with sizes 4, 8, 16 and asserts maxSize == 16.
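The merge behavior that these suggested tests would pin down can be sketched in isolation. The struct layout, the integer `componentType` key, and the function name below are hypothetical simplifications of what the review describes, not the PR's `_insertOrUpdateCooperativeVectorType`.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical, simplified entry: one record per component type.
struct ExampleVectorType
{
    int componentType = 0;
    std::uint32_t maxSize = 0;
    bool usedForTrainingOp = false;
};

// Keeps `sorted` ordered by componentType with one entry per key.
// A repeated key merges into the existing entry: maxSize takes the
// maximum and usedForTrainingOp is OR-ed.
void insertOrUpdateVectorType(
    std::vector<ExampleVectorType>& sorted,
    int componentType,
    std::uint32_t size,
    bool usedForTrainingOp)
{
    auto it = std::lower_bound(
        sorted.begin(), sorted.end(), componentType,
        [](const ExampleVectorType& entry, int key) { return entry.componentType < key; });

    if (it != sorted.end() && it->componentType == componentType)
    {
        it->maxSize = std::max(it->maxSize, size);
        it->usedForTrainingOp = it->usedForTrainingOp || usedForTrainingOp;
        return;
    }

    ExampleVectorType entry;
    entry.componentType = componentType;
    entry.maxSize = size;
    entry.usedForTrainingOp = usedForTrainingOp;
    sorted.insert(it, entry);
}
```

A test in the spirit of the suggestion would feed sizes 4, 16, and 8 for the same component type and assert a single entry with `maxSize == 16`.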
uint32_t rowCount = 0;
uint32_t columnCount = 0;

SlangCooperativeMatrixUse use = SLANG_COOPERATIVE_MATRIX_USE_A;
🔵 Question: Default .use = SLANG_COOPERATIVE_MATRIX_USE_A for invalid types
All other fields in CooperativeMatrixType default to zero/NONE to indicate "not valid," but .use defaults to SLANG_COOPERATIVE_MATRIX_USE_A (value 0) because the SlangCooperativeMatrixUse enum has no NONE sentinel. This means a default-constructed or partially-initialized struct has a field that looks valid (USE_A) even when the struct is logically invalid (componentType == NONE).
The doc comment says componentType == NONE is the validity check, so this is consistent — but it could confuse callers who inspect .use without first checking componentType. Is this intentional, or would it be worth adding a SLANG_COOPERATIVE_MATRIX_USE_NONE sentinel to the existing enum (recognizing that changes the enum layout)?
@cmarcelo, can you address some of my concerns?
Expose cooperative matrix and cooperative vector metadata through a new `ICooperativeTypesMetadata` interface queried from `IMetadata`. This information can be used to query the combinations of types that the graphics driver supports.
Closes #10021.